#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from ... import Mapping
from ..exceptions import UnknownParsingFunctionError
from .stix2_observable_objects_converter import (
    ExternalSTIX2SampleObservableConverter,
    InternalSTIX2SampleObservableConverter)
from .stix2converter import (
    ExternalSTIX2Converter, InternalSTIX2Converter, STIX2Converter,
    _MAIN_PARSER_TYPING)
from .stix2mapping import (
    ExternalSTIX2Mapping, InternalSTIX2Mapping, STIX2Mapping)
from abc import ABCMeta
from pymisp import MISPGalaxy, MISPGalaxyCluster, MISPObject
from stix2.v20.sdo import Malware as Malware_v20
from stix2.v21.sdo import Malware as Malware_v21
from typing import Iterator, Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    from ..external_stix2_to_misp import ExternalSTIX2toMISPParser
    from ..internal_stix2_to_misp import InternalSTIX2toMISPParser

# A STIX 2.0 or STIX 2.1 Malware SDO: the input type of the converters below.
_MALWARE_TYPING = Union[Malware_v20, Malware_v21]


class STIX2MalwareMapping(metaclass=ABCMeta):
    """Mapping tables shared by the external and internal Malware converters.

    Two private tables are exposed through class methods:
    - ``malware_meta_mapping()`` pairs STIX Malware field names with meta
      field names (e.g. ``aliases`` -> ``synonyms``);
    - ``malware_object_mapping()`` pairs STIX Malware field names with MISP
      attribute definitions (``type`` + ``object_relation``).
    """
    # Single definition reused for both ``labels`` and ``malware_types`` below:
    # both end up as ``malware_type`` attributes. It is the same dict object
    # in both entries, so it should be treated as read-only.
    __malware_type_attribute = {
        'type': 'text', 'object_relation': 'malware_type'
    }
    # STIX Malware field -> meta field name
    __malware_meta_mapping = Mapping(
        aliases='synonyms',
        architecture_execution_envs='architecture_execution_envs',
        capabilities='capabilities',
        first_seen='first_seen',
        implementation_languages='implementation_languages',
        is_family='is_family',
        last_seen='last_seen',
        malware_types='malware_types'
    )
    # STIX Malware field -> MISP attribute definition (type + object_relation)
    __malware_object_mapping = Mapping(
        aliases=STIX2Mapping.alias_attribute(),
        architecture_execution_envs={
            'type': 'text', 'object_relation': 'architecture_execution_env'
        },
        capabilities={'type': 'text', 'object_relation': 'capability'},
        description=STIX2Mapping.description_attribute(),
        first_seen={'type': 'datetime', 'object_relation': 'first_seen'},
        implementation_languages={
            'type': 'text', 'object_relation': 'implementation_language'
        },
        is_family={'type': 'boolean', 'object_relation': 'is_family'},
        labels=__malware_type_attribute,
        last_seen={'type': 'datetime', 'object_relation': 'last_seen'},
        malware_types=__malware_type_attribute,
        name=STIX2Mapping.name_attribute()
    )

    @classmethod
    def malware_meta_mapping(cls) -> Mapping:
        """Return the STIX Malware field -> meta field name mapping."""
        return cls.__malware_meta_mapping

    @classmethod
    def malware_object_mapping(cls) -> Mapping:
        """Return the STIX Malware field -> MISP attribute mapping."""
        return cls.__malware_object_mapping


class STIX2MalwareConverter(STIX2Converter, metaclass=ABCMeta):
    """Logic shared by the external and internal Malware SDO converters.

    Subclasses are expected to set ``self._mapping`` and ``self._converter``
    (the sample observable converter used below) in their ``__init__``.
    """

    def __init__(self, main: _MAIN_PARSER_TYPING):
        self._set_main_parser(main)

    def _convert_malware_object(self, malware: _MALWARE_TYPING) -> MISPObject:
        """Build a MISP ``malware`` object from ``malware`` and register it.

        Every attribute produced by ``_generic_parser`` is added to the new
        object, which is then passed to the main parser's
        ``_add_misp_object``; that call's result is returned.
        """
        misp_object = self._create_misp_object('malware', malware)
        for attribute_kwargs in self._generic_parser(malware):
            misp_object.add_attribute(**attribute_kwargs)
        return self.main_parser._add_misp_object(misp_object, malware)

    def _parse_malware_object(self, malware: _MALWARE_TYPING):
        """Convert ``malware`` to a MISP object and link its observables.

        - each entry of ``operating_system_refs`` is parsed as a software
          object that the malware object references ('executable-on');
        - each entry of ``sample_refs`` is parsed by the observable converter
          method matching its STIX type, and the resulting object references
          the malware object ('sample-of').
        """
        malware_object = self._convert_malware_object(malware)
        if hasattr(malware, 'operating_system_refs'):
            for os_ref in malware.operating_system_refs:
                software_object = self._converter._parse_software_observable(
                    os_ref, malware
                )
                malware_object.add_reference(
                    software_object.uuid, 'executable-on'
                )
        if not hasattr(malware, 'sample_refs'):
            return
        for sample_ref in malware.sample_refs:
            # The STIX type prefix of the id ("<type>--<uuid>") selects the
            # converter method, e.g. ``_parse_file_observable_object``.
            observable_type = sample_ref.partition('--')[0]
            sample_parser = getattr(
                self._converter,
                f'_parse_{observable_type}_observable_object'
            )
            sample_object = sample_parser(sample_ref, malware)
            # NOTE(review): ExternalSTIX2MalwareConverter._convert_malware_objects
            # links samples with 'associated-with' instead — confirm which
            # relationship is intended.
            sample_object.add_reference(malware_object.uuid, 'sample-of')

    # Error handling
    def _malware_error(self, malware_id: str, exception: Exception):
        """Record a parsing error for the Malware object ``malware_id``.

        The traceback of ``exception`` is formatted by the main parser and
        stored through its ``_add_error`` method.
        """
        traceback_text = self.main_parser._parse_traceback(exception)
        message = (
            'Error while parsing the Malware object with id '
            f'{malware_id}: {traceback_text}'
        )
        self.main_parser._add_error(message)


class ExternalSTIX2MalwareMapping(STIX2MalwareMapping, ExternalSTIX2Mapping):
    """Malware mappings used when parsing external STIX 2.x content.

    Defines nothing of its own: it combines the shared Malware tables of
    ``STIX2MalwareMapping`` with ``ExternalSTIX2Mapping``.
    """


class ExternalSTIX2MalwareConverter(
        STIX2MalwareConverter, ExternalSTIX2Converter):
    """Convert Malware SDOs from external (non-MISP) STIX 2.x content.

    Depending on ``is_family``, on the presence of ``sample_refs`` and on the
    main parser's ``force_contextual_data`` flag, a Malware becomes a galaxy
    cluster, a MISP ``malware`` object, or both.
    """

    def __init__(self, main: 'ExternalSTIX2toMISPParser'):
        super().__init__(main)
        self._mapping = ExternalSTIX2MalwareMapping
        self._converter = ExternalSTIX2SampleObservableConverter(self)

    def parse(self, malware_ref: str):
        """Parse the Malware object referenced by ``malware_ref``.

        Exceptions raised while converting the object are reported through
        ``_malware_error``. Exceptions from fetching the object with
        ``_get_stix_object`` propagate to the caller.
        """
        malware = self.main_parser._get_stix_object(malware_ref)
        try:
            # ``is_family`` is a STIX 2.1 field and may be missing (e.g. on a
            # STIX 2.0 Malware). Such objects are treated as families in every
            # branch. Previously, the ``sample_refs`` branch read the
            # attribute directly and failed with an AttributeError.
            is_family = getattr(malware, 'is_family', True)
            if hasattr(malware, 'sample_refs'):
                if is_family or self.main_parser.force_contextual_data:
                    self._parse_malware_object_and_galaxy(malware)
                else:
                    self._parse_malware_object(malware)
            elif is_family:
                self._parse_galaxy(malware)
            else:
                if self.main_parser.force_contextual_data:
                    self._parse_galaxy(malware)
                self._parse_malware_object(malware)
        except Exception as exception:
            self._malware_error(malware.id, exception)

    def _convert_malware_objects(
            self, malware: _MALWARE_TYPING) -> Iterator[MISPObject]:
        """Convert ``malware`` and yield every MISP object it produces.

        Yields, in order:
        - the software objects parsed from ``operating_system_refs``, each
          referenced by the malware object ('executable-on');
        - the sample objects parsed from ``sample_refs``, each referencing
          the malware object ('associated-with');
        - the malware object itself.

        The referenced objects are created lazily, as the generator is
        consumed, so callers must exhaust it.
        """
        malware_object = self._convert_malware_object(malware)
        if hasattr(malware, 'operating_system_refs'):
            for OS_ref in malware.operating_system_refs:
                software = self._converter._parse_software_observable(
                    OS_ref, malware
                )
                malware_object.add_reference(software.uuid, 'executable-on')
                yield software
        if hasattr(malware, 'sample_refs'):
            for sample_ref in malware.sample_refs:
                # "<type>--<uuid>" -> ``_parse_<type>_observable_object``
                parser = getattr(
                    self._converter,
                    f"_parse_{sample_ref.split('--')[0]}_observable_object"
                )
                sample = parser(sample_ref, malware)
                # NOTE(review): STIX2MalwareConverter._parse_malware_object
                # uses 'sample-of' for the same link — confirm which
                # relationship is intended.
                sample.add_reference(malware_object.uuid, 'associated-with')
                yield sample
        yield malware_object

    def _create_cluster(self, malware: _MALWARE_TYPING,
                        galaxy_type: Optional[str] = None) -> MISPGalaxyCluster:
        """Build a galaxy cluster from ``malware``.

        The cluster meta combines ``_handle_meta_fields``, the external
        references, the kill chain phases and the labels. It is only set on
        the cluster when non-empty.
        """
        malware_args = self._create_cluster_args(malware, galaxy_type)
        meta = self._handle_meta_fields(malware)
        if hasattr(malware, 'external_references'):
            meta.update(
                self._handle_external_references(malware.external_references)
            )
        if hasattr(malware, 'kill_chain_phases'):
            meta['kill_chain'] = self._handle_kill_chain_phases(
                malware.kill_chain_phases
            )
        if hasattr(malware, 'labels'):
            self._handle_labels(meta, malware.labels)
        if meta:
            malware_args['meta'] = meta
        return self.main_parser._create_misp_galaxy_cluster(**malware_args)

    def _parse_malware_object_and_galaxy(self, malware: _MALWARE_TYPING):
        """Convert ``malware`` to a galaxy cluster and to MISP objects.

        The parsed cluster is attached to every ``to_ids`` attribute of the
        objects yielded by ``_convert_malware_objects``:
        - as tags (``tag_names``) when ``galaxies_as_tags`` is set;
        - otherwise as a MISP galaxy holding the cluster, in which case the
          cluster is also flagged as used for the current event.
        """
        self._parse_galaxy(malware)
        parsed = self.main_parser._clusters[malware.id]
        if self.main_parser.galaxies_as_tags:
            for misp_object in self._convert_malware_objects(malware):
                for attribute in misp_object.attributes:
                    if attribute.to_ids:
                        for tag_name in parsed['tag_names']:
                            attribute.add_tag(tag_name)
        else:
            cluster = parsed['cluster']
            misp_galaxy = MISPGalaxy()
            misp_galaxy.from_dict(**self.main_parser._galaxies[cluster.type])
            misp_galaxy.add_galaxy_cluster(**cluster)
            parsed['used'][self.event_uuid] = True
            for misp_object in self._convert_malware_objects(malware):
                for attribute in misp_object.attributes:
                    if attribute.to_ids:
                        attribute.add_galaxy(misp_galaxy)


class InternalSTIX2MalwareMapping(STIX2MalwareMapping, InternalSTIX2Mapping):
    """Malware mappings used when parsing STIX 2.x content produced by MISP.

    On top of the shared Malware tables, this class defines the mapping for
    MISP ``script`` objects built from Malware SDOs.
    """
    # STIX Malware field (or ``x_misp_*`` custom property) -> ``script``
    # object attribute definition. Presumably consumed through
    # ``_generic_parser(..., feature='script')`` — confirm.
    __script_object_mapping = Mapping(
        name=STIX2Mapping.filename_attribute(),
        description=InternalSTIX2Mapping.comment_text_attribute(),
        implementation_languages=STIX2Mapping.language_attribute(),
        x_misp_script=InternalSTIX2Mapping.script_attribute(),
        x_misp_state=InternalSTIX2Mapping.state_attribute()
    )

    @classmethod
    def script_object_mapping(cls) -> Mapping:
        """Return the STIX field -> ``script`` object attribute mapping."""
        return cls.__script_object_mapping


class InternalSTIX2MalwareConverter(
        STIX2MalwareConverter, InternalSTIX2Converter):
    """Convert Malware SDOs found in STIX 2.x content generated from MISP.

    The converter method applied to each Malware is selected from its
    ``labels``.
    """

    def __init__(self, main: 'InternalSTIX2toMISPParser'):
        super().__init__(main)
        self._mapping = InternalSTIX2MalwareMapping
        self._converter = InternalSTIX2SampleObservableConverter(self)

    def parse(self, malware_ref: str):
        """Parse the Malware object referenced by ``malware_ref``.

        Raises:
            UnknownParsingFunctionError: if the method name derived from the
                labels is not defined on this converter.

        Exceptions raised by the selected method itself are reported through
        ``_malware_error`` rather than propagated.
        """
        malware = self.main_parser._get_stix_object(malware_ref)
        method_name = self._handle_mapping_from_labels(
            malware.labels, malware.id
        )
        try:
            parse_method = getattr(self, method_name)
        except AttributeError:
            raise UnknownParsingFunctionError(method_name)
        try:
            parse_method(malware)
        except Exception as exception:
            self._malware_error(malware.id, exception)

    def _create_cluster(self, malware: _MALWARE_TYPING,
                        description: Optional[str] = None,
                        galaxy_type: Optional[str] = None) -> MISPGalaxyCluster:
        """Build a galaxy cluster from ``malware``.

        When the external references provide an ``external_id``, the cluster
        value and synonyms are adjusted by
        ``_handle_cluster_value_with_synonyms``. The kill chain phases and
        labels are then added to the meta, which is only set when non-empty.
        """
        cluster_args = self._create_cluster_args(
            malware, galaxy_type, description=description
        )
        meta = self._handle_meta_fields(malware)
        if hasattr(malware, 'external_references'):
            references_meta = self._handle_external_references(
                malware.external_references
            )
            meta.update(references_meta)
        if meta.get('external_id'):
            self._handle_cluster_value_with_synonyms(cluster_args, meta)
        if hasattr(malware, 'kill_chain_phases'):
            kill_chain = self._handle_kill_chain_phases(
                malware.kill_chain_phases
            )
            meta['kill_chain'] = kill_chain
        if hasattr(malware, 'labels'):
            self._handle_labels(meta, malware.labels)
        if meta:
            cluster_args['meta'] = meta
        return self.main_parser._create_misp_galaxy_cluster(**cluster_args)

    def _parse_script_object(self, malware: _MALWARE_TYPING):
        """Convert ``malware`` into a MISP ``script`` object and register it.

        Besides the generically mapped fields, the ``x_misp_script_as_attachment``
        custom property becomes a ``script-as-attachment`` attachment attribute.
        If the property is a dict, it is merged into the attribute definition;
        otherwise it is used as the attribute value.
        """
        script_object = self._create_misp_object('script', malware)
        for attribute in self._generic_parser(malware, feature='script'):
            script_object.add_attribute(**attribute)
        if hasattr(malware, 'x_misp_script_as_attachment'):
            attachment = malware.x_misp_script_as_attachment
            attachment_attribute = {
                'type': 'attachment',
                'object_relation': 'script-as-attachment'
            }
            if isinstance(attachment, dict):
                attachment_attribute.update(attachment)
            else:
                attachment_attribute['value'] = attachment
            script_object.add_attribute(**attachment_attribute)
        self.main_parser._add_misp_object(script_object, malware)
